package com.wtwd.device.filter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.factory.rewrite.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.DefaultServerRequest;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Global gateway filter that logs every incoming request: its id, headers, method, path and
 * either the query parameters or, for requests declaring a JSON content type, the request body.
 *
 * <p>Reading a reactive request body consumes it. For JSON requests the body is therefore written
 * into a {@link CachedBodyOutputMessage}, and the filter chain continues with a decorated request
 * that serves the cached body to downstream filters.
 *
 * <p>Values of credential-carrying headers are masked before being logged.
 *
 * @author zdl
 * @create 2021-10-22 15:19
 */
@Component
public class GlobalCacheRequestBodyFilter implements GlobalFilter, Ordered {

    private static final Logger logger = LoggerFactory.getLogger(GlobalCacheRequestBodyFilter.class);

    private static final String CONTENT_TYPE = "Content-Type";

    private static final String CONTENT_TYPE_JSON = "application/json";

    private static final String REQUEST_PREFIX = "\n request:";
    private static final String REQUEST_HEADER = "\n header:";
    private static final String REQUEST_PARAMS = "\n params:";
    private static final String REQUEST_METHOD = "\n method:";
    private static final String REQUEST_API = "\n api:";

    /** Replacement written to the log instead of a sensitive header value. */
    private static final String MASKED_VALUE = "******";

    /** Headers whose values are credentials and must never reach the log. */
    private static final String[] SENSITIVE_HEADERS = {
            HttpHeaders.AUTHORIZATION,
            HttpHeaders.PROXY_AUTHORIZATION,
            HttpHeaders.COOKIE
    };

    /**
     * Logs the request and passes it down the chain.
     *
     * <ul>
     *   <li>GET requests, and any request whose {@code Content-Type} does not contain
     *       {@code application/json}, are logged with their query parameters and forwarded
     *       untouched.</li>
     *   <li>Other requests with a JSON content type have their body read as a string, logged,
     *       cached, and replayed to the rest of the chain through a request decorator.</li>
     * </ul>
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return a {@code Mono} that completes when the rest of the chain has completed
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        HttpHeaders headers = request.getHeaders();

        // Built once as an immutable String so the reactive callbacks below never mutate shared
        // state (a re-subscription would otherwise append the body a second time).
        String summary = new StringBuilder()
                .append(REQUEST_PREFIX).append(request.getId())
                .append(REQUEST_HEADER).append(maskSensitiveHeaders(headers))
                .append(REQUEST_METHOD).append(request.getMethod())
                .append(REQUEST_API).append(request.getPath())
                .toString();

        String contentType = headers.getFirst(CONTENT_TYPE);
        boolean hasJsonBody = request.getMethod() != HttpMethod.GET
                && contentType != null
                && contentType.contains(CONTENT_TYPE_JSON);

        if (!hasJsonBody) {
            // Nothing to cache: log the query parameters and forward the exchange as-is.
            logger.info("{}{}{}", summary, REQUEST_PARAMS, request.getQueryParams());
            return chain.filter(exchange);
        }

        ServerRequest serverRequest = new DefaultServerRequest(exchange);
        // Read the body, logging it as a side effect. An absent body completes empty, so that
        // case is logged separately and still results in nothing being written downstream.
        Mono<String> loggedBody = serverRequest.bodyToMono(String.class)
                .doOnNext(body -> logger.info("{}{}{}", summary, REQUEST_PARAMS, body))
                .switchIfEmpty(Mono.defer(() -> {
                    logger.info("{}{}", summary, REQUEST_PARAMS);
                    return Mono.<String>empty();
                }));

        // Headers handed to the cached message; Content-Length is dropped here and read back
        // from this same object by the decorator once the body has been written.
        HttpHeaders cachedHeaders = new HttpHeaders();
        cachedHeaders.putAll(headers);
        cachedHeaders.remove(HttpHeaders.CONTENT_LENGTH);
        CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, cachedHeaders);

        BodyInserter<Mono<String>, ? super CachedBodyOutputMessage> bodyInserter =
                BodyInserters.fromPublisher(loggedBody, String.class);

        return bodyInserter.insert(outputMessage, new BodyInserterContext())
                .then(Mono.defer(() -> {
                    ServerHttpRequest replaying =
                            replayingRequest(exchange.getRequest(), cachedHeaders, outputMessage);
                    return chain.filter(exchange.mutate().request(replaying).build());
                }));
    }

    /**
     * Returns the order value {@code -21}. Under Spring's {@link Ordered} contract, lower values
     * are applied with higher precedence.
     */
    @Override
    public int getOrder() {
        return -21;
    }

    /**
     * Returns a copy of the given headers in which the values of {@link #SENSITIVE_HEADERS} are
     * replaced by {@link #MASKED_VALUE}. The supplied headers are not modified.
     *
     * @param headers the request headers
     * @return a new header map that is safe to write to the log
     */
    private static HttpHeaders maskSensitiveHeaders(HttpHeaders headers) {
        HttpHeaders loggable = new HttpHeaders();
        loggable.putAll(headers);
        for (String name : SENSITIVE_HEADERS) {
            if (loggable.containsKey(name)) {
                // set() installs a fresh value list, so the list shared with the original
                // headers through putAll() is left untouched.
                loggable.set(name, MASKED_VALUE);
            }
        }
        return loggable;
    }

    /**
     * Wraps {@code original} so that its body is served from {@code cachedBody} and its headers
     * carry a length indication taken from {@code cachedHeaders}.
     *
     * @param original      the request to decorate
     * @param cachedHeaders the headers that were given to {@code cachedBody}
     * @param cachedBody    the output message holding the already-read body
     * @return the decorated request
     */
    private static ServerHttpRequest replayingRequest(ServerHttpRequest original,
                                                      HttpHeaders cachedHeaders,
                                                      CachedBodyOutputMessage cachedBody) {
        return new ServerHttpRequestDecorator(original) {
            @Override
            public HttpHeaders getHeaders() {
                long contentLength = cachedHeaders.getContentLength();
                HttpHeaders httpHeaders = new HttpHeaders();
                httpHeaders.putAll(super.getHeaders());
                if (contentLength > 0) {
                    httpHeaders.setContentLength(contentLength);
                } else {
                    // No known length for the cached body: advertise chunked transfer instead.
                    httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                }
                return httpHeaders;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                return cachedBody.getBody();
            }
        };
    }
}
